{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
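    "# JAX / Flax --> ONNX\n",
    "\n",
    "This notebook converts a Brax PPO policy MLP (trained with JAX/Flax) into an ONNX model: the policy is rebuilt as a Keras model, the checkpoint weights are copied into it, the model is exported with `tf2onnx`, and the ONNX, TensorFlow, and JAX predictions are compared."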
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "# Set before MuJoCo / JAX are imported in the next cell.\n",
    "os.environ.update(\n",
    "    MUJOCO_GL=\"egl\",\n",
    "    XLA_PYTHON_CLIENT_PREALLOCATE=\"false\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from brax.training.agents.ppo import networks as ppo_networks\n",
    "from mujoco_playground.config import locomotion_params, manipulation_params\n",
    "from mujoco_playground import locomotion, manipulation\n",
    "import functools\n",
    "import pickle\n",
    "import jax.numpy as jp\n",
    "import jax\n",
    "import tf2onnx\n",
    "import tensorflow as tf\n",
    "from tensorflow.keras import layers\n",
    "import onnxruntime as rt\n",
    "from brax.training.acme import running_statistics\n",
    "from brax.training.checkpoint import load"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Environment whose trained policy is exported; its PPO config supplies the\n",
    "# network hyperparameters used in the cells below.\n",
    "env_name = \"BerkeleyHumanoidJoystickFlatTerrain\"\n",
    "ppo_params = locomotion_params.brax_ppo_config(env_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def identity_observation_preprocessor(observation, preprocessor_params):\n",
    "  \"\"\"Returns `observation` unchanged, ignoring `preprocessor_params`.\"\"\"\n",
    "  del preprocessor_params  # Unused.\n",
    "  return observation\n",
    "\n",
    "\n",
    "# Only the brax PPO train.py script wires in observation normalization when\n",
    "# normalize_observations is True, so the normalizer is passed explicitly here.\n",
    "network_factory = functools.partial(\n",
    "    ppo_networks.make_ppo_networks,\n",
    "    preprocess_observations_fn=running_statistics.normalize,\n",
    "    **ppo_params.network_factory,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the environment with its default config; it provides the observation\n",
    "# and action sizes used to build the networks.\n",
    "env_cfg = locomotion.get_default_config(env_name)\n",
    "env = locomotion.load(env_name, config=env_cfg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sizes reported by the environment; printed as a quick sanity check.\n",
    "obs_size, act_size = env.observation_size, env.action_size\n",
    "print(obs_size, act_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Instantiate the PPO networks for this environment's observation/action sizes.\n",
    "ppo_network = network_factory(obs_size, act_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Checkpoint directory to export. Set the BRAX_CKPT_PATH environment variable to\n",
    "# point at your own checkpoint; the fallback below is the path used on the\n",
    "# original author's machine and will not exist elsewhere.\n",
    "ckpt_path = os.environ.get(\n",
    "    \"BRAX_CKPT_PATH\",\n",
    "    \"/home/kevin/mujoco_playground/mujoco_playground/experimental/learning/checkpoints/BerkeleyHumanoidJoystickFlatTerrain-20250111-001442\",\n",
    ")\n",
    "\n",
    "params = load(ckpt_path)\n",
    "\n",
    "# Destination file for the exported ONNX model.\n",
    "output_path = \"bh_policy.onnx\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep only the first two checkpoint entries: params[0] is read below for the\n",
    "# observation mean/std and params[1] for the policy weights.\n",
    "params = (params[0], params[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reference JAX policy, used further down to cross-check the exported model.\n",
    "# deterministic=True is requested so the outputs can be compared directly.\n",
    "make_inference_fn = ppo_networks.make_inference_fn(ppo_network)\n",
    "inference_fn = make_inference_fn(params, deterministic=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MLP(tf.keras.Model):\n",
    "    \"\"\"Keras counterpart of the brax policy MLP, used as the source for ONNX export.\n",
    "\n",
    "    Inputs are optionally normalized as `(inputs - mean) / std`, passed through a\n",
    "    stack of Dense layers named `hidden_{i}` (held in a Sequential named `MLP_0`),\n",
    "    then the output is split in half along the last axis and `tanh` of the first\n",
    "    half is returned.\n",
    "\n",
    "    Args:\n",
    "      layer_sizes: Output width of each Dense layer, in order.\n",
    "      activation: Activation applied after each activated Dense layer.\n",
    "      kernel_init: Kernel initializer passed to every Dense layer.\n",
    "      activate_final: If False, the last Dense layer gets neither an activation\n",
    "        nor a layer norm.\n",
    "      bias: Whether the Dense layers use a bias term.\n",
    "      layer_norm: If True, a LayerNormalization follows every activated Dense\n",
    "        layer.\n",
    "      mean_std: Optional `(mean, std)` pair used to normalize the inputs.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        layer_sizes,\n",
    "        activation=tf.nn.relu,\n",
    "        kernel_init=\"lecun_uniform\",\n",
    "        activate_final=False,\n",
    "        bias=True,\n",
    "        layer_norm=False,\n",
    "        mean_std=None,\n",
    "    ):\n",
    "        super().__init__()\n",
    "\n",
    "        self.layer_sizes = layer_sizes\n",
    "        self.activation = activation\n",
    "        self.kernel_init = kernel_init\n",
    "        self.activate_final = activate_final\n",
    "        self.bias = bias\n",
    "        self.layer_norm = layer_norm\n",
    "\n",
    "        # Normalization statistics are stored as non-trainable variables so they\n",
    "        # are part of the model that gets exported.\n",
    "        if mean_std is not None:\n",
    "            self.mean = tf.Variable(mean_std[0], trainable=False, dtype=tf.float32)\n",
    "            self.std = tf.Variable(mean_std[1], trainable=False, dtype=tf.float32)\n",
    "        else:\n",
    "            self.mean = None\n",
    "            self.std = None\n",
    "\n",
    "        self.mlp_block = tf.keras.Sequential(name=\"MLP_0\")\n",
    "        last_index = len(self.layer_sizes) - 1\n",
    "        for i, size in enumerate(self.layer_sizes):\n",
    "            # Decide per layer, at construction time, whether it is activated.\n",
    "            # Patching `layers[-1].activation` afterwards misses the final Dense\n",
    "            # layer whenever a LayerNormalization was appended after it.\n",
    "            activated = self.activate_final or i != last_index\n",
    "            self.mlp_block.add(\n",
    "                layers.Dense(\n",
    "                    size,\n",
    "                    activation=self.activation if activated else None,\n",
    "                    kernel_initializer=self.kernel_init,\n",
    "                    name=f\"hidden_{i}\",\n",
    "                    use_bias=self.bias,\n",
    "                )\n",
    "            )\n",
    "            # Layer norm only follows activated layers, so a non-activated final\n",
    "            # layer emits the raw Dense output.\n",
    "            if self.layer_norm and activated:\n",
    "                self.mlp_block.add(layers.LayerNormalization(name=f\"layer_norm_{i}\"))\n",
    "\n",
    "        # Kept from the original; not read anywhere else in this notebook.\n",
    "        self.submodules = [self.mlp_block]\n",
    "\n",
    "    def call(self, inputs):\n",
    "        \"\"\"Returns `tanh` of the first half of the MLP output for `inputs`.\n",
    "\n",
    "        `inputs` may be a tensor or a list/tuple whose first element is the\n",
    "        observation tensor.\n",
    "        \"\"\"\n",
    "        if isinstance(inputs, (list, tuple)):\n",
    "            inputs = inputs[0]\n",
    "        if self.mean is not None and self.std is not None:\n",
    "            inputs = (inputs - self.mean) / self.std\n",
    "        logits = self.mlp_block(inputs)\n",
    "        # Only the first half of the output is used; the second half is discarded.\n",
    "        loc, _ = tf.split(logits, 2, axis=-1)\n",
    "        return tf.tanh(loc)\n",
    "\n",
    "def make_policy_network(\n",
    "    param_size,\n",
    "    mean_std,\n",
    "    hidden_layer_sizes=[256, 256],\n",
    "    activation=tf.nn.relu,\n",
    "    kernel_init=\"lecun_uniform\",\n",
    "    layer_norm=False,\n",
    "):\n",
    "    \"\"\"Builds the policy `MLP`: the hidden layers plus a `param_size`-wide output layer.\"\"\"\n",
    "    # A fresh list is built here, so the caller's sequence is never mutated.\n",
    "    sizes = [*hidden_layer_sizes, param_size]\n",
    "    return MLP(\n",
    "        layer_sizes=sizes,\n",
    "        mean_std=mean_std,\n",
    "        activation=activation,\n",
    "        kernel_init=kernel_init,\n",
    "        layer_norm=layer_norm,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Running observation statistics stored with the checkpoint.\n",
    "mean = params[0].mean[\"state\"]\n",
    "std = params[0].std[\"state\"]\n",
    "\n",
    "# The Keras model takes the statistics as TF tensors rather than JAX arrays.\n",
    "mean_std = tuple(tf.convert_to_tensor(stat) for stat in (mean, std))\n",
    "\n",
    "# The output layer is twice the action size because the model splits its output\n",
    "# in half and keeps only the first half.\n",
    "# NOTE(review): the activation is hard-coded to swish -- confirm it matches the\n",
    "# activation the checkpoint was trained with.\n",
    "tf_policy_network = make_policy_network(\n",
    "    act_size * 2,\n",
    "    mean_std,\n",
    "    hidden_layer_sizes=ppo_params.network_factory.policy_hidden_layer_sizes,\n",
    "    activation=tf.nn.swish,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Call the model once on a dummy single-observation batch; presumably this is what\n",
    "# creates the layer weights that are overwritten below. The print is a quick\n",
    "# check of the output shape.\n",
    "example_input = tf.zeros((1, obs_size[\"state\"][0]))\n",
    "example_output = tf_policy_network(example_input)\n",
    "print(example_output.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "\n",
    "def transfer_weights(jax_params, tf_model, block_name=\"MLP_0\"):\n",
    "    \"\"\"Copies Dense-layer weights from a JAX/Flax parameter dict into a Keras model.\n",
    "\n",
    "    Args:\n",
    "      jax_params: Mapping from layer name to that layer's parameters, e.g.\n",
    "        {'hidden_0': {'kernel': array, 'bias': array}, 'hidden_1': {...}}.\n",
    "        Each name must match a layer inside the `block_name` sub-model. The\n",
    "        'bias' entry is optional.\n",
    "      tf_model: Keras model containing a sub-model named `block_name`.\n",
    "      block_name: Name of the sub-model whose layers receive the weights.\n",
    "\n",
    "    Raises:\n",
    "      ValueError: If `tf_model` has no layer named `block_name`.\n",
    "\n",
    "    Entries that have no matching layer, or whose layer is not a Dense layer,\n",
    "    are reported and skipped. The success message is printed only when every\n",
    "    entry was transferred; otherwise the skipped names are listed.\n",
    "    \"\"\"\n",
    "    # Resolve the block once instead of on every iteration.\n",
    "    block = tf_model.get_layer(block_name)\n",
    "\n",
    "    skipped = []\n",
    "    for layer_name, layer_params in jax_params.items():\n",
    "        try:\n",
    "            tf_layer = block.get_layer(name=layer_name)\n",
    "        except ValueError:\n",
    "            print(f\"Layer {layer_name} not found in TensorFlow model.\")\n",
    "            skipped.append(layer_name)\n",
    "            continue\n",
    "        if not isinstance(tf_layer, tf.keras.layers.Dense):\n",
    "            print(f\"Unhandled layer type in {layer_name}: {type(tf_layer)}\")\n",
    "            skipped.append(layer_name)\n",
    "            continue\n",
    "\n",
    "        kernel = np.array(layer_params['kernel'])\n",
    "        weights = [kernel]\n",
    "        message = f\"Transferring Dense layer {layer_name}, kernel shape {kernel.shape}\"\n",
    "        # Layers created without a bias have no 'bias' entry to copy.\n",
    "        if 'bias' in layer_params:\n",
    "            bias = np.array(layer_params['bias'])\n",
    "            weights.append(bias)\n",
    "            message += f\", bias shape {bias.shape}\"\n",
    "        print(message)\n",
    "        tf_layer.set_weights(weights)\n",
    "\n",
    "    if skipped:\n",
    "        # Skipped layers keep their initial weights, so do not claim success.\n",
    "        print(f\"WARNING: weights were NOT transferred for: {skipped}\")\n",
    "    else:\n",
    "        print(\"Weights transferred successfully.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copy the checkpoint's per-layer parameters (presumably the policy network's)\n",
    "# into the Keras model.\n",
    "transfer_weights(params[1]['params'], tf_policy_network)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# All-ones observation batch, wrapped in a one-element list that the model unwraps.\n",
    "obs_shape = (1, obs_size[\"state\"][0])\n",
    "test_input = [np.ones(obs_shape, dtype=np.float32)]\n",
    "\n",
    "# Input signature for the export; \"obs\" is the input name used when running the\n",
    "# ONNX session below.\n",
    "spec = [tf.TensorSpec(shape=obs_shape, dtype=tf.float32, name=\"obs\")]\n",
    "\n",
    "# Calling the model on example data also builds it before conversion.\n",
    "tensorflow_pred = tf_policy_network(test_input)[0]\n",
    "print(f\"Tensorflow prediction: {tensorflow_pred}\")\n",
    "\n",
    "tf_policy_network.output_names = ['continuous_actions']\n",
    "\n",
    "# opset 11 matches isaac lab.\n",
    "model_proto, _ = tf2onnx.convert.from_keras(\n",
    "    tf_policy_network,\n",
    "    input_signature=spec,\n",
    "    opset=11,\n",
    "    output_path=output_path,\n",
    ")\n",
    "\n",
    "# Load the exported file with ONNX Runtime on CPU.\n",
    "output_names = ['continuous_actions']\n",
    "providers = ['CPUExecutionProvider']\n",
    "m = rt.InferenceSession(output_path, providers=providers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prepare inputs for ONNX Runtime: the same all-ones observation, keyed by the\n",
    "# input name declared in the export signature.\n",
    "onnx_input = {\n",
    "  'obs': np.ones((1, obs_size[\"state\"][0]), dtype=np.float32)\n",
    "}\n",
    "# Take the first requested output, then the first (only) batch element.\n",
    "onnx_pred = m.run(output_names, onnx_input)[0][0]\n",
    "\n",
    "print(\"ONNX prediction:\", onnx_pred)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the reference JAX policy on the same all-ones `state` observation.\n",
    "# NOTE(review): `privileged_state` is filled with zeros -- presumably the policy\n",
    "# network does not read it; confirm.\n",
    "test_input = {\n",
    "    'state': jp.ones(obs_size[\"state\"]),\n",
    "    'privileged_state': jp.zeros(obs_size[\"privileged_state\"])\n",
    "}\n",
    "# The second return value of the inference function is not needed here.\n",
    "jax_pred, _ = inference_fn(test_input, jax.random.PRNGKey(0))\n",
    "print(jax_pred)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "# The three predictions for the same all-ones observation.\n",
    "raw_predictions = {\n",
    "    'onnx': onnx_pred,\n",
    "    'tensorflow': tensorflow_pred,\n",
    "    'jax': jax_pred,\n",
    "}\n",
    "for label, pred in raw_predictions.items():\n",
    "    print(label, pred.shape)\n",
    "\n",
    "# Flatten to NumPy so the three frameworks' outputs can be compared directly.\n",
    "predictions = {\n",
    "    label: np.ravel(np.asarray(pred)) for label, pred in raw_predictions.items()\n",
    "}\n",
    "\n",
    "# Quantify the agreement with the JAX reference instead of relying on the plot.\n",
    "reference = predictions['jax']\n",
    "for label in ('onnx', 'tensorflow'):\n",
    "    candidate = predictions[label]\n",
    "    if candidate.shape != reference.shape:\n",
    "        print(f\"{label}: shape {candidate.shape} differs from jax {reference.shape}\")\n",
    "        continue\n",
    "    max_diff = np.max(np.abs(candidate - reference))\n",
    "    print(f\"max |{label} - jax| = {max_diff:.3e}\")\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "for label, pred in predictions.items():\n",
    "    ax.plot(pred, label=label)\n",
    "ax.set(\n",
    "    title='Policy output: ONNX vs TensorFlow vs JAX',\n",
    "    xlabel='Action index',\n",
    "    ylabel='Action value',\n",
    ")\n",
    "ax.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "# env_cfg = locomotion.get_default_config(env_name)\n",
    "# env = locomotion.load(env_name, config=env_cfg)\n",
    "# jit_reset = jax.jit(env.reset)\n",
    "# jit_step = jax.jit(env.step)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # Test the policy.\n",
    "\n",
    "# # env_cfg = locomotion.get_default_config(env_name)\n",
    "# # env_cfg.init_from_crouch = 0.0\n",
    "# # env = locomotion.load(env_name, config=env_cfg)\n",
    "# # env_cfg = manipulation.get_default_config(env_name)\n",
    "# # env = manipulation.load(env_name, config=env_cfg)\n",
    "# # jit_reset = jax.jit(env.reset)\n",
    "# # jit_step = jax.jit(env.step)\n",
    "\n",
    "# x = 0.8\n",
    "# y = 0.0\n",
    "# yaw = 0.3\n",
    "# command = jp.array([x, y, yaw])\n",
    "# # actions = []\n",
    "\n",
    "# states = [state := jit_reset(jax.random.PRNGKey(555))]\n",
    "# state.info[\"command\"] = command\n",
    "# for _ in range(env_cfg.episode_length):\n",
    "#   onnx_input = {'obs': np.array(state.obs[\"state\"].reshape(1, -1))}\n",
    "#   action = m.run(output_names, onnx_input)[0][0]\n",
    "#   state = jit_step(state, jp.array(action))\n",
    "#   state.info[\"command\"] = command\n",
    "#   states.append(state)\n",
    "#   # actions.append(state.info[\"motor_targets\"])\n",
    "#   # actions.append(action)\n",
    "#   if state.done:\n",
    "#     print(\"Unexpected termination.\")\n",
    "#     break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "# import mediapy as media\n",
    "# fps = 1.0 / env.dt\n",
    "\n",
    "# frames = env.render(\n",
    "#     states,\n",
    "#     camera=\"track\",\n",
    "#     width=640*2,\n",
    "#     height=480*2,\n",
    "# )\n",
    "# media.show_video(frames, fps=fps, loop=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
